"""
The RPyC protocol 
"""
import sys
import select
import weakref
import itertools
import cPickle as pickle
from threading import Lock
from rpyc.utils.lib import WeakValueDict, RefCountingColl
from rpyc.core import consts, brine, vinegar, netref
from rpyc.core.async import AsyncResult


class PingError(Exception):
    """Raised by ``Connection.ping`` when the data echoed back by the other
    party differs from the data that was sent."""

# Default per-connection configuration; Connection copies this dict and
# overlays the caller-supplied config on top of it.
DEFAULT_CONFIG = {
    # --- attribute access policy (consulted by Connection._check_attr) ---
    "allow_safe_attrs": True,
    "allow_exposed_attrs": True,
    "allow_public_attrs": False,
    "allow_all_attrs": False,
    # names reachable when allow_safe_attrs is on
    "safe_attrs": set([
        '__abs__', '__add__', '__and__', '__cmp__', '__contains__',
        '__delitem__', '__delslice__', '__div__', '__divmod__', '__doc__',
        '__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__',
        '__getslice__', '__gt__', '__hash__', '__hex__', '__iadd__',
        '__iand__', '__idiv__', '__ifloordiv__', '__ilshift__', '__imod__',
        '__imul__', '__index__', '__int__', '__invert__', '__ior__',
        '__ipow__', '__irshift__', '__isub__', '__iter__', '__itruediv__',
        '__ixor__', '__le__', '__len__', '__long__', '__lshift__', '__lt__',
        '__mod__', '__mul__', '__ne__', '__neg__', '__new__', '__nonzero__',
        '__oct__', '__or__', '__pos__', '__pow__', '__radd__', '__rand__',
        '__rdiv__', '__rdivmod__', '__repr__', '__rfloordiv__',
        '__rlshift__', '__rmod__', '__rmul__', '__ror__', '__rpow__',
        '__rrshift__', '__rshift__', '__rsub__', '__rtruediv__', '__rxor__',
        '__setitem__', '__setslice__', '__str__', '__sub__', '__truediv__',
        '__xor__', 'next', '__length_hint__', '__enter__', '__exit__',
    ]),
    # prefix that marks an attribute as explicitly exposed
    "exposed_prefix": "exposed_",
    # master switches for remote getattr / setattr / delattr
    "allow_getattr": True,
    "allow_setattr": False,
    "allow_delattr": False,
    # --- exception transport (passed through to vinegar dump/load) ---
    "include_local_traceback": True,
    "instantiate_custom_exceptions": False,
    "import_custom_exceptions": False,
    "instantiate_oldstyle_exceptions": False,  # which don't derive from Exception
    "propagate_SystemExit_locally": False,  # whether to propagate SystemExit locally or to the other party
    # --- miscellaneous ---
    "allow_pickle": False,
    "connid": None,  # None means "auto-generate conn<N>"
    "credentials": None,
}

# source of the numeric suffix for auto-generated connection ids
_connection_id_generator = itertools.count(1)

class Connection(object):
    """The RPyC connection (also known as the RPyC protocol).

    * service: the service class to expose; it is instantiated with a weak
      proxy to this connection and becomes the local root object
    * channel: the channel over which messages are passed
    * config: this connection's config dict (overriding parameters from the
      default config dict); may be omitted
    * _lazy: whether to postpone initializing the service. The default is
      False, meaning the service's ``on_connect`` is called as part of
      construction; if set to True, you will need to call ``_init_service``
      manually later
    """
    def __init__(self, service, channel, config = None, _lazy = False):
        # remain "closed" until construction finishes, so that close() (and
        # therefore __del__) on a half-built instance is a no-op
        self._closed = True
        self._config = DEFAULT_CONFIG.copy()
        if config is not None:
            self._config.update(config)
        if self._config["connid"] is None:
            self._config["connid"] = "conn%d" % (_connection_id_generator.next(),)
        
        self._channel = channel
        self._seqcounter = itertools.count()
        self._recvlock = Lock()
        self._sendlock = Lock()
        self._sync_replies = {}       # seq -> (isexc, obj) awaiting a sync_request
        self._async_callbacks = {}    # seq -> callback(isexc, obj)
        self._local_objects = RefCountingColl()
        self._last_traceback = None
        self._proxy_cache = WeakValueDict()
        self._netref_classes_cache = {}
        self._remote_root = None
        # the service gets a weak proxy so it does not keep the connection alive
        self._local_root = service(weakref.proxy(self))
        if not _lazy:
            self._init_service()
        self._closed = False
    def _init_service(self):
        """Call the local service's ``on_connect`` hook."""
        self._local_root.on_connect()
    
    def __del__(self):
        self.close()
    def __enter__(self):
        return self
    def __exit__(self, t, v, tb):
        self.close()
    def __repr__(self):
        # splice the connection id into object's default "<X object at 0x..>"
        a, b = object.__repr__(self).split(" object ")
        return "%s %r object %s" % (a, self._config["connid"], b)
    
    #
    # IO 
    #
    def _cleanup(self, _anyway = True):
        """Tear down local state: close the channel, call the local service's
        ``on_disconnect`` hook and drop all pending replies, callbacks, local
        objects and cached proxies/classes. Unless ``_anyway`` is true this
        is skipped when the connection is already marked closed."""
        if self._closed and not _anyway:
            return
        self._closed = True
        self._channel.close()
        self._local_root.on_disconnect()
        self._sync_replies.clear()
        self._async_callbacks.clear()
        self._local_objects.clear()
        self._proxy_cache.clear()
        self._netref_classes_cache.clear()
        self._last_traceback = None
        self._remote_root = None
        self._local_root = None
    def close(self, _catchall = True):
        """Close the connection. Does nothing if already closed.

        The other party is sent a HANDLE_CLOSE request (fire-and-forget);
        an EOFError while doing so is always ignored, other exceptions are
        ignored unless ``_catchall`` is False. Local cleanup always runs.
        """
        if self._closed:
            return
        self._closed = True
        try:
            try:
                self._async_request(consts.HANDLE_CLOSE)
            except EOFError: 
                pass
            except Exception:
                if not _catchall:
                    raise
        finally:
            self._cleanup(_anyway = True)
    
    @property
    def closed(self):
        """True once the connection has been closed."""
        return self._closed
    def fileno(self):
        """Return the underlying channel's file descriptor."""
        return self._channel.fileno()
    
    def ping(self, data = "the world is a vampire!" * 20, timeout = 3):
        """assert that the other party is functioning properly: send ``data``
        and raise PingError if the echoed reply differs"""
        res = self.async_request(consts.HANDLE_PING, data, timeout = timeout)
        if res.value != data:
            raise PingError("echo mismatches sent data")
    
    def _send(self, msg, seq, args):
        """Serialize one (msg, seq, args) message and write it to the
        channel; the send lock keeps concurrent senders from interleaving."""
        data = brine.dump((msg, seq, args))
        self._sendlock.acquire()
        try:
            self._channel.send(data)
        finally:
            self._sendlock.release()
    def _send_request(self, handler, args):
        """Send a request for ``handler`` and return its sequence number."""
        seq = self._seqcounter.next()
        self._send(consts.MSG_REQUEST, seq, (handler, self._box(args)))
        return seq
    def _send_reply(self, seq, obj):
        self._send(consts.MSG_REPLY, seq, self._box(obj))
    def _send_exception(self, seq, exctype, excval, exctb):
        exc = vinegar.dump(exctype, excval, exctb, 
            include_local_traceback = self._config["include_local_traceback"])
        self._send(consts.MSG_EXCEPTION, seq, exc)
    
    #
    # boxing
    #
    def _box(self, obj):
        """store a local object in such a way that it could be recreated on
        the remote party either by-value or by-reference.

        Returns a (label, payload) pair: brine-dumpable values go by value,
        tuples are boxed element-wise, netrefs owned by this connection are
        sent back as the peer's own object id, and anything else is kept in
        ``_local_objects`` and sent as (id, class name, module name)."""
        if brine.dumpable(obj):
            return consts.LABEL_VALUE, obj
        if type(obj) is tuple:
            return consts.LABEL_TUPLE, tuple(self._box(item) for item in obj)
        elif isinstance(obj, netref.BaseNetref) and obj.____conn__() is self:
            return consts.LABEL_LOCAL_REF, obj.____oid__
        else:
            self._local_objects.add(obj)
            cls = getattr(obj, "__class__", type(obj))
            return consts.LABEL_REMOTE_REF, (id(obj), cls.__name__, cls.__module__)
    
    def _unbox(self, package):
        """recreate a local object representation of the remote object: if the
        object is passed by value, just return it; if the object is passed by
        reference, create (or reuse a cached) netref to it.

        Raises ValueError for an unknown label."""
        label, value = package
        if label == consts.LABEL_VALUE:
            return value
        if label == consts.LABEL_TUPLE:
            return tuple(self._unbox(item) for item in value)
        if label == consts.LABEL_LOCAL_REF:
            return self._local_objects[value]
        if label == consts.LABEL_REMOTE_REF:
            oid, clsname, modname = value
            if oid in self._proxy_cache:
                return self._proxy_cache[oid]
            proxy = self._netref_factory(oid, clsname, modname)
            self._proxy_cache[oid] = proxy
            return proxy
        raise ValueError("invalid label %r" % (label,))
    
    def _netref_factory(self, oid, clsname, modname):
        """Build a netref for remote object ``oid``. The netref class is
        looked up per (clsname, modname) in this connection's cache, then in
        the builtin cache, and otherwise created from a HANDLE_INSPECT round
        trip and cached."""
        typeinfo = (clsname, modname)
        if typeinfo in self._netref_classes_cache:
            cls = self._netref_classes_cache[typeinfo]
        elif typeinfo in netref.builtin_classes_cache:
            cls = netref.builtin_classes_cache[typeinfo]
        else:
            info = self.sync_request(consts.HANDLE_INSPECT, oid)
            cls = netref.class_factory(clsname, modname, info)
            self._netref_classes_cache[typeinfo] = cls
        return cls(weakref.ref(self), oid)
    
    #
    # dispatching
    #
    def _dispatch_request(self, seq, raw_args):
        """Run the requested handler and send back its result, or the
        exception it raised."""
        try:
            handler, args = raw_args
            args = self._unbox(args)
            res = self._HANDLERS[handler](self, *args)
        except KeyboardInterrupt:
            raise
        except:
            # deliberately broad: every failure (including SystemExit, unless
            # configured otherwise) is forwarded to the requesting party
            t, v, tb = sys.exc_info()
            self._last_traceback = tb
            if t is SystemExit and self._config["propagate_SystemExit_locally"]:
                raise
            self._send_exception(seq, t, v, tb)
        else:
            self._send_reply(seq, res)
    
    def _dispatch_reply(self, seq, raw):
        """Deliver a reply to its async callback, or park it for
        sync_request."""
        obj = self._unbox(raw)
        if seq in self._async_callbacks:
            self._async_callbacks.pop(seq)(False, obj)
        else:
            self._sync_replies[seq] = (False, obj)
    
    def _dispatch_exception(self, seq, raw):
        """Rebuild a remote exception and deliver it like a reply, flagged as
        an exception."""
        obj = vinegar.load(raw, 
            import_custom_exceptions = self._config["import_custom_exceptions"], 
            instantiate_custom_exceptions = self._config["instantiate_custom_exceptions"],
            instantiate_oldstyle_exceptions = self._config["instantiate_oldstyle_exceptions"])
        if seq in self._async_callbacks:
            self._async_callbacks.pop(seq)(True, obj)
        else:
            self._sync_replies[seq] = (True, obj)
    
    #
    # serving
    #
    def _recv(self, timeout, wait_for_lock):
        """Receive one raw message under the receive lock.

        Returns None if the lock could not be taken (when ``wait_for_lock``
        is false) or nothing arrived within ``timeout``. On EOFError the
        connection is closed and the error re-raised."""
        if not self._recvlock.acquire(wait_for_lock):
            return None
        try:
            try:
                if self._channel.poll(timeout):
                    data = self._channel.recv()
                else:
                    data = None
            except EOFError:
                self.close()
                raise
        finally:
            self._recvlock.release()
        return data
        
    def _dispatch(self, data):
        """Decode a raw message and route it by message type."""
        msg, seq, args = brine.load(data)
        if msg == consts.MSG_REQUEST:
            self._dispatch_request(seq, args)
        elif msg == consts.MSG_REPLY:
            self._dispatch_reply(seq, args)
        elif msg == consts.MSG_EXCEPTION:
            self._dispatch_exception(seq, args)
        else:
            raise ValueError("invalid message type: %r" % (msg,))

    def poll(self, timeout = 0):
        """serve a single transaction, should one arrive in the given
        interval. note that handling a request/reply may trigger nested 
        requests, which are all part of the transaction. unlike serve(), this
        does not wait if another thread holds the receive lock.
        
        returns True if one was served, False otherwise"""
        data = self._recv(timeout, wait_for_lock = False)
        if not data:
            return False
        self._dispatch(data)
        return True
    
    def serve(self, timeout = 1):
        """serve a single request or reply that arrives within the given 
        time frame (default is 1 sec). note that the dispatching of a request
        might trigger multiple (nested) requests, thus this function may be 
        reentrant. returns True if a request or reply were received, False 
        otherwise."""
        
        data = self._recv(timeout, wait_for_lock = True)
        if not data:
            return False
        self._dispatch(data)
        return True
    
    def serve_all(self):
        """serve all requests and replies while the connection is alive; the
        connection is always closed on exit"""
        try:
            try:
                while True:
                    self.serve(0.1)
            except select.error:
                # only tolerated once the connection has been closed;
                # re-raise the active exception otherwise (the old code
                # referenced an unbound name here and raised NameError)
                if not self.closed:
                    raise
            except EOFError:
                pass
        finally:
            self.close()
    
    def poll_all(self, timeout = 0):
        """serve all requests and replies that arrive within the given interval.
        returns True if at least one was served, False otherwise"""
        at_least_once = False
        try:
            while self.poll(timeout):
                at_least_once = True
        except EOFError:
            pass
        return at_least_once
    
    #
    # requests
    #
    def sync_request(self, handler, *args):
        """send a request and wait for the reply to arrive; if the other
        party answered with an exception, that exception is raised here"""
        seq = self._send_request(handler, args)
        while seq not in self._sync_replies:
            self.serve(0.1)
        isexc, obj = self._sync_replies.pop(seq)
        if isexc:
            raise obj
        else:
            return obj
    
    def _async_request(self, handler, args = (), callback = (lambda a, b: None)):
        """Send a request; ``callback(isexc, obj)`` is invoked when the reply
        is dispatched. The default callback discards it."""
        seq = self._send_request(handler, args)
        self._async_callbacks[seq] = callback
    def async_request(self, handler, *args, **kwargs):
        """send a request and return an AsyncResult object, which will 
        eventually hold the reply. the only accepted keyword is ``timeout``"""
        timeout = kwargs.pop("timeout", None)
        if kwargs:
            raise TypeError("got unexpected keyword argument %r" % (kwargs.keys()[0],))
        res = AsyncResult(weakref.proxy(self))
        self._async_request(handler, args, res)
        if timeout is not None:
            res.set_expiry(timeout)
        return res
    
    @property
    def root(self):
        """fetch (and cache) the root object of the other party"""
        if self._remote_root is None:
            self._remote_root = self.sync_request(consts.HANDLE_GETROOT)
        return self._remote_root
    
    #
    # attribute access
    #
    def _check_attr(self, obj, name):
        """Apply the attribute-access policy from the config.

        Returns the attribute name to actually use (possibly the
        ``exposed_``-prefixed variant), or False if access is denied."""
        if self._config["allow_exposed_attrs"]:
            if name.startswith(self._config["exposed_prefix"]):
                name2 = name
            else:
                name2 = self._config["exposed_prefix"] + name
            if hasattr(obj, name2):
                return name2
        if self._config["allow_all_attrs"]:
            return name
        if self._config["allow_safe_attrs"] and name in self._config["safe_attrs"]:
            return name
        if self._config["allow_public_attrs"] and not name.startswith("_"):
            return name
        return False
    
    def _access_attr(self, oid, name, args, overrider, param, default):
        """Common implementation of remote get/set/delattr.

        If the object's type defines ``overrider`` (e.g. ``_rpyc_getattr``)
        it is used unconditionally; otherwise the config switch ``param``
        and _check_attr must both allow the access, and ``default`` (the
        builtin getattr/setattr/delattr) is applied. Raises AttributeError
        when access is denied."""
        if type(name) is not str:
            raise TypeError("attr name must be a string")
        obj = self._local_objects[oid]
        accessor = getattr(type(obj), overrider, None)
        if accessor is None:
            name2 = self._check_attr(obj, name)
            if not self._config[param] or not name2:
                raise AttributeError("cannot access %r" % (name,))
            accessor = default
            name = name2
        return accessor(obj, name, *args)
    
    #
    # handlers (collected into _HANDLERS below by their _handle_ prefix)
    #
    def _handle_ping(self, data):
        return data
    def _handle_close(self):
        self._cleanup()
    def _handle_getroot(self):
        return self._local_root
    def _handle_del(self, oid):
        self._local_objects.decref(oid)
    def _handle_repr(self, oid):
        return repr(self._local_objects[oid])
    def _handle_str(self, oid):
        return str(self._local_objects[oid])
    def _handle_cmp(self, oid, other):
        # cmp() might enter recursive resonance... yet another workaround
        obj = self._local_objects[oid]
        try:
            return type(obj).__cmp__(obj, other)
        except TypeError:
            return NotImplemented
    def _handle_hash(self, oid):
        return hash(self._local_objects[oid])
    def _handle_call(self, oid, args, kwargs):
        return self._local_objects[oid](*args, **dict(kwargs))
    def _handle_dir(self, oid):
        return tuple(dir(self._local_objects[oid]))
    def _handle_inspect(self, oid):
        return tuple(netref.inspect_methods(self._local_objects[oid]))
    def _handle_getattr(self, oid, name):
        return self._access_attr(oid, name, (), "_rpyc_getattr", "allow_getattr", getattr)
    def _handle_delattr(self, oid, name):
        return self._access_attr(oid, name, (), "_rpyc_delattr", "allow_delattr", delattr)
    def _handle_setattr(self, oid, name, value):
        return self._access_attr(oid, name, (value,), "_rpyc_setattr", "allow_setattr", setattr)
    def _handle_callattr(self, oid, name, args, kwargs):
        return self._handle_getattr(oid, name)(*args, **dict(kwargs))
    def _handle_pickle(self, oid, proto):
        if not self._config["allow_pickle"]: 
            raise ValueError("pickling is disabled")
        return pickle.dumps(self._local_objects[oid], proto)
    def _handle_buffiter(self, oid, count):
        # pull up to ``count`` items from a local iterator in one round trip
        items = []
        obj = self._local_objects[oid]
        for i in xrange(count):
            try:
                items.append(obj.next())
            except StopIteration:
                break
        return tuple(items)
    
    # collect handlers: _handle_foo is registered under consts.HANDLE_FOO
    _HANDLERS = {}
    for name, obj in locals().items():
        if name.startswith("_handle_"):
            name2 = "HANDLE_" + name[8:].upper()
            if hasattr(consts, name2):
                _HANDLERS[getattr(consts, name2)] = obj
            else:
                raise NameError("no constant defined for %r" % (name,))
    del name, name2, obj


